查看原文
其他

【新开源报道 40】OpenAI 开源最新工具包,十倍模型计算时间仅增20%

2018-01-20 开源最前线
程序猿(ID:imkuqin) 猿妹 编译

编译自:https://github.com/openai/gradient-checkpointing


近日,OpenAI 在 GitHub 上开源最新工具包 gradient-checkpointing,该工具包通过设置梯度检查点(gradient-checkpointing)来节省内存资源。据悉,对于普通的前馈模型,可以在计算时间只增加 20% 的情况下,在 GPU 上训练比之前大十多倍的模型。

工具包 gradient-checkpointing


授权协议:MIT

开发语言:Python

操作系统:跨平台

项目地址:https://github.com/openai/gradient-checkpointing



通过梯度检查点来节省内存资源


训练非常深的神经网络需要大量的内存。Tim Salimans和Yaroslav Bulatov联合开发的 gradient-checkpointing包中的工具,你可以通过计算来取消这些内存使用,从而使你的模型更好的存储训练。对于前馈模型,我们能够在计算时间仅增加20%的情况下,将10倍以上的大型模型放到我们的GPU上。


训练深度神经网络的记忆密集部分是通过反向传播计算损失的梯度。通过在模型训练时,定义计算图中检查点,并在反向传播期间重新计算这些节点之间的部分,可以降低的存储器成本时计算该梯度。



当训练由n层组成的深度前馈神经网络时,我们可以以这种方式将存储器消耗减少到O(sqrt(n)),代价就是执行一个额外的正向传递(参见例如以次线性内存成本训练深度网)。该存储库使用Tensorflow graph editor 功能来实现自动重写反向传递的计算图。



如何运行的


对于具有n层的简单前馈神经网络,用于获得梯度的计算图如下所示:



神经网络层的激活用f标记节点。在正向传递期间,所有这些节点按顺序进行计算。相对于层的激活和参数损失的梯度用b标记节点表示。在反向传递期间,所有这些节点都按照相反的顺序进行计算。对于f个节点的结果是计算b个节点所需要的,因此所有f节点在正向传递过程都会保存在存储器中。只有当反向传播已经进行到足以计算所有的依赖关系或子节点f,它才从内存中删除。下面我们显示这些节点的计算顺序。紫色阴影圆圈表示在任何给定时间哪个节点需要保存在内存中。



如上所述的计算方案是最佳的:它每个节点只计算一次。但是,如果我们想重新计算节点,那就可以简单地重新计算每个节点的正向传递,这样能够节省大量内存,执行顺序和使用的内存如下所示:



使用这个策略,计算梯度所需的存储器在神经网络层数n中是恒定的,这在存储器方面是最佳的。但是,节点评估的数目现在按n ^ 2进行缩放,而之前将其缩放为n:n个节点,每一个节点都按n次的顺序重新计算。因此,对于深度网络来说,计算时间太长,这使得该方法在深度学习中不切实际。


为了在内存和计算之间取得平衡,我们需要提出一个允许节点重新计算的策略,将神经网络激活的一个子集标记为检查节点。



这些检查点节点在正向传递后保留在内存中,而其余节点至多重新计算一次。重新计算后,非检查点节点将保存在内存中,直到不再 42 35531 42 15231 0 0 4150 0 0:00:08 0:00:03 0:00:05 4150需要它们。


对于简单前馈神经网络的情况,所有的神经元激活节点都是由正向通道定义的图形的图分隔符或关节点。这意味着我们只需要在backprop期间计算b节点时,重新计算b节点和它之前的最后一个检查点之间的节点。当backprop进展到足以到达检查点节点时,所有从它重新计算的节点都可以从内存中删除。计算和内存使用的结果顺序如下所示



安装要求



当执行这一程序时,需要保证能找到CUPTI。这时可以运行 export LD_LIBRARY_PATH="${LD_LIBRARY_PATH}:/usr/local/cuda/extras/CUPTI/lib64"



用法


这个库提供嵌入式功能,tf.gradients的替代方案,可以输入如下程序来引入:from memory_saving_gradients import gradients并且使用 tf.gradients 函数一样使用 gradients 函数来计算参数损失的梯度。


gradients 函数有一个额外的功能——检查点(checkpoints)。检查点会对 gradients 函数进行指示——在计算图的前向传播中,图中的哪一部分节点是用户想要检查的点。随后,会在后向传播中重新计算检查点之间的节点。


覆盖 tf.gradients 函数,使用 gradients 函数的另一个方法是直接覆盖 tf.gradients 函数,方法如下:



这样操作之后,所有调用 tf.gradients 函数的请求都会使用新的节省内存的方法。



测试


测试文件夹包含脚本,用于测试代码的正确性,并分析各种模型的内存使用情况。修改代码后,您可以./run_all_tests.sh从该文件夹运行以执行测试。



图:在CIFAR10数据集上,使用常规的梯度函数和使用最新的优化内存函数,在不同层数的 ResNet 网络下的内存占用情况和执行时间的对比


附:新开源报道汇总

《【新开源报道 39】Mozilla 开源稍后阅读应用 Pocket 代码》

《【新开源报道 38】那个被美国通缉的程序员,开发了一款保护你隐私的 App》

《【新开源报道 37】Facebook 开源语音识别工具包wav2letter》

《【新开源报道 36】有效减少错误代码!Instagram 开源用于 Python 3 的 MonkeyType 工具》

《【新开源报道 35】国内第一家私有视频通信软件 Tucodec 开源》

《【新开源报道 34】AMD 开源基于 Mesa 的 Vulkan Linux 驱动》

《【新开源报道 33】安全软件公司 Avast 开源机器码反编译器 RetDec》

《【新开源报道 32】谷歌开源 TFGAN:轻量级生成对抗网络工具库》

《【新开源报道 31】国内首套开源持续集成(CI) 解决方案 flow.ci 开源啦》

《【开源推荐 30】苹果开源领域又一深造:开源机器学习框架 Turi Create》

《【新开源推荐 29】AI开发者的福音:360公司宣布开源深度学习调度平台 XLearning!》

《【新开源推荐 28】百度正式开源 Linux 发行版 MesaLock Linux》

《【新开源报道 27】百度开源高性能 Python 分布式计算框架 Bigflow》

《【新开源报道 26】滴滴开源基于 Vue.js 的移动端组件库 cube-ui》

《【新开源报道 25】Google 开源 Docker 镜像差异分析工具 container-diff》

《【新开源报道 24】美团点评开源MySQL闪回工具 —— MyFlash》

《【新开源报道 23】IBM 推出首套开源现代化字体 —— IBM Plex》

《【新开源报道 22】Microsoft 开源用于 VS Code 的 Java Debugger》

《【新开源报道 21】阿里开源容器技术Pouch和P2P文件分发系统“蜻蜓”》

《【新开源报道 20】Uber正式开源其分布式跟踪系统Jaeger》

《【新开源报道 19】Uber与斯坦福大学开源深度概率编程语言Pyro》

《【新开源报道 18】谷歌开放内部工具 Colaboratory 来协助 AI 开发》

《【新开源报道 17】这波开源满分!清华大学开源网络嵌入的工具包 —— OpenNE

《【新开源报道 16】AI开发者福音!微软亚马逊联合发布深度学习库 Gloun》

《【新开源报道 15】谷歌发布量子开源软件,量子计算机对科学家免费开放 》

《【新开元报道 14】微软开源用于Spark的深度学习库MMLSpark》

《【新开源报道 13】Facebook 开源帮助开发者消灭最顽固的软件 bug 的工具》

《【新开源报道 12】不只是阿里巴巴的操作系统,AliOS 宣布开源》

《【新开源报道 11】重磅!阿里巴巴正式开源全球化OpenMessaging和ApsaraCache项目》

《【新开源报道 10】IBM 和谷歌等巨头联手为开发者推出开源容器安全工具Grafeas》

《【新开源报道 9】Google开源Abseil,为C++和Python开发提供支持》

《【新开源报道 8】serverless 领域的福音!Oracle 宣布开源 Fn project》

《【新开源报道 7】苹果在 GitHub 上公布 macOS 和 iOS 内核源码》

《【新开源报道 6】百度开源移动端深度学习框架mobile-deep-learning(MDL)》

《【新开源报道 5】百度正式开源其 RPC 框架 brpc》

《【新开源报道 4】IBM 开源动态的应用服务器运行时环境 Open Liberty》

《【新开源报道 3】微信后台团队最近开源力作:PhxQueue分布式队列》

《【新开源报道 2】喜大普奔!阿里即将开源 ApsaraCache,云数据库 Redis 版分支》

【新开源报道 1】腾讯 Web UI 解决方案 QMUI Web 正式回迁开源》



●本文编号147,以后想阅读这篇文章直接输入147即可

●输入m获取文章目录

↓↓↓ 点击"阅读原文" 进入GitHub详情页 

您可能也对以下帖子感兴趣

文章有问题?点此查看未经处理的缓存